{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import cv2\n",
    "import torch\n",
    "from time import time\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import torch.nn.functional as F\n",
    "from torch.utils import data\n",
    "from datagen import Datagen\n",
    "from image_utils import *\n",
    "from v2_93 import *\n",
    "import math\n",
    "import datetime\n",
    "\n",
    "# --- Run configuration: module-level constants read by the cells below ---\n",
    "# Device string handed to every .to(...) call; hard-coded, so a CUDA GPU is required.\n",
    "device = 'cuda'\n",
    "# Directory listed for checkpoints when resuming. NOTE(review): absolute, machine-specific path.\n",
    "ckpt_dir = 'E:/ModelCkpts/StabNet-multigrid-future_frames'\n",
    "# First epoch of the training loop; overwritten when a checkpoint is loaded.\n",
    "starting_epoch = 0\n",
    "# Frame height, width and channel count; `shape` keeps the full (H, W, C) tuple.\n",
    "H,W,C = shape = (256,256,3)\n",
    "batch_size = 4\n",
    "EPOCHS = 10\n",
    "# Mesh cells per axis; the network predicts (grid_h + 1) x (grid_w + 1) vertices.\n",
    "grid_h, grid_w = 15,15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models import transformer\n",
    "class StabNet(nn.Module):\n",
    "    '''\n",
    "    Mesh-offset regressor on top of an ImageNet-pretrained VGG-19 feature extractor.\n",
    "\n",
    "    The first convolution is widened to 15 input channels and the network\n",
    "    predicts one 2-D offset per vertex of a (grid_h + 1) x (grid_w + 1) mesh.\n",
    "\n",
    "    Args:\n",
    "        trainable_layers: number of trailing encoder parameter tensors left\n",
    "            trainable (each conv layer owns a weight and a bias tensor);\n",
    "            all earlier encoder parameters are frozen.\n",
    "    '''\n",
    "    def __init__(self,trainable_layers = 10):\n",
    "        super(StabNet, self).__init__()\n",
    "        # Load the ImageNet-pretrained VGG-19 model\n",
    "        vgg19 = torchvision.models.vgg19(weights='IMAGENET1K_V1')\n",
    "        # Pretrained first-conv kernels for RGB input\n",
    "        rgb_weights = vgg19.features[0].weight.clone() #torch.Size([64, 3, 3, 3])\n",
    "        # Tile the RGB kernels 5x along the input-channel axis -> [64, 15, 3, 3]\n",
    "        tiled_rgb_weights = rgb_weights.repeat(1,5,1,1)\n",
    "        # Widen the first layer from 3 to 15 input channels.\n",
    "        # bias=False is kept so state dicts saved from this class still load.\n",
    "        vgg19.features[0] = nn.Conv2d(15,64, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        # Initialise the widened layer with the tiled pretrained kernels\n",
    "        vgg19.features[0].weight = nn.Parameter(tiled_rgb_weights)\n",
    "        # Keep only the convolutional feature stack, minus its final max-pool\n",
    "        self.encoder = nn.Sequential(*vgg19.features[:-1])\n",
    "        # Freeze everything except the last `trainable_layers` parameter tensors\n",
    "        # of the encoder. Counting is done on the encoder itself: the VGG\n",
    "        # classifier is discarded, so its parameters must not use up the budget.\n",
    "        encoder_params = list(self.encoder.parameters())\n",
    "        first_trainable = max(len(encoder_params) - trainable_layers, 0)\n",
    "        for idx, param in enumerate(encoder_params):\n",
    "            param.requires_grad = idx >= first_trainable\n",
    "        # Regression head: pooled 512-d feature -> flattened (x, y) offsets of all vertices\n",
    "        self.regressor = nn.Sequential(nn.Linear(512,2048),\n",
    "                                       nn.ReLU(),\n",
    "                                       nn.Linear(2048,1024),\n",
    "                                       nn.ReLU(),\n",
    "                                       nn.Linear(1024,512),\n",
    "                                       nn.ReLU(),\n",
    "                                       nn.Linear(512, ((grid_h + 1) * (grid_w + 1) * 2)))\n",
    "        #self.regressor[-1].bias.data.fill_(0)\n",
    "        encoder_param_count = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad)\n",
    "        regressor_param_count = sum(p.numel() for p in self.regressor.parameters() if p.requires_grad)\n",
    "        print(\"Total Trainable encoder Parameters: \", encoder_param_count)\n",
    "        print(\"Total Trainable regressor Parameters: \", regressor_param_count)\n",
    "        print(\"Total Trainable parameters:\",regressor_param_count + encoder_param_count)\n",
    "    \n",
    "    def forward(self, x_tensor):\n",
    "        '''\n",
    "        Args:\n",
    "            x_tensor: (batch, 15, H, W) input tensor.\n",
    "        Returns:\n",
    "            (batch, grid_h + 1, grid_w + 1, 2) tensor of per-vertex mesh offsets.\n",
    "        '''\n",
    "        x_batch_size = x_tensor.size()[0]\n",
    "        feature_maps = self.encoder(x_tensor)\n",
    "        # Global average pooling over the spatial axes -> (batch, 512)\n",
    "        pooled = torch.mean(feature_maps, dim=[2, 3])\n",
    "        offsets = self.regressor(pooled)\n",
    "        return offsets.view(x_batch_size,grid_h + 1,grid_w + 1,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total Trainable mobilenet Parameters:  2360320\n",
      "Total Trainable regressor Parameters:  3936256\n",
      "Total Trainable parameters: 6296576\n"
     ]
    }
   ],
   "source": [
    "# Build the network, move it to the configured device and switch to training mode\n",
    "stabnet = StabNet(trainable_layers=10)\n",
    "stabnet = stabnet.to(device).train()\n",
    "# Adam over every parameter; frozen ones receive no gradient and are skipped by the optimiser\n",
    "optimizer = torch.optim.Adam(params=stabnet.parameters(), lr=2e-5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 16, 16, 2])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: one random 15-channel 256x256 input should give a (1, grid_h + 1, grid_w + 1, 2) output\n",
    "dummy_frames = torch.randn(1,15,256,256).float().to(device)\n",
    "out = stabnet(dummy_frames)\n",
    "out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded weights from E:/ModelCkpts/StabNet-multigrid-future_frames\\stabnet_2023-11-01_23-17-37.pth\n",
      "Reduced learning rate to 5.000000000000001e-07\n"
     ]
    }
   ],
   "source": [
    "def _checkpoint_timestamp(filename):\n",
    "    '''Parse the datetime from names like stabnet_2023-11-01_23-17-37.pth; return None if the name does not match.'''\n",
    "    fields = os.path.splitext(filename)[0].split('_')\n",
    "    if len(fields) < 2:\n",
    "        return None\n",
    "    try:\n",
    "        return datetime.datetime.strptime('_'.join(fields[-2:]), '%Y-%m-%d_%H-%M-%S')\n",
    "    except ValueError:\n",
    "        return None\n",
    "\n",
    "# A missing directory simply means there is nothing to resume from\n",
    "ckpts = os.listdir(ckpt_dir) if os.path.isdir(ckpt_dir) else []\n",
    "# Ignore stray files whose name carries no parsable timestamp\n",
    "ckpts = [name for name in ckpts if _checkpoint_timestamp(name) is not None]\n",
    "if ckpts:\n",
    "    # Newest first, ordered by the full date AND time: sorting on the\n",
    "    # time-of-day alone would rank an older day's late run above a newer one\n",
    "    ckpts = sorted(ckpts, key=_checkpoint_timestamp, reverse=True)\n",
    "    # Get the filename of the latest checkpoint\n",
    "    latest = os.path.join(ckpt_dir, ckpts[0])\n",
    "    # Load the latest checkpoint onto the configured device\n",
    "    state = torch.load(latest, map_location=device)\n",
    "    stabnet.load_state_dict(state['model'])\n",
    "    optimizer.load_state_dict(state['optimizer'])\n",
    "    starting_epoch = state['epoch'] + 1\n",
    "    # NOTE(review): the optimizer state restores the learning rate saved in the\n",
    "    # checkpoint, so this 10x decay compounds on every resume - confirm that is intended\n",
    "    optimizer.param_groups[0]['lr'] *= 0.1\n",
    "    print('Loaded weights from', latest)\n",
    "    print('Reduced learning rate to', optimizer.param_groups[0]['lr'])\n",
    "else:\n",
    "    print('No checkpoint found in', ckpt_dir, '- starting from scratch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_warp(net_out,img):\n",
    "    '''\n",
    "    Warp `img` with the mesh predicted by the network.\n",
    "\n",
    "    Inputs:\n",
    "        net_out: torch.Size([batch, grid_h + 1, grid_w + 1, 2]) vertex offsets, added to a\n",
    "                 regular mesh spanning [-1, 1]; last axis is ordered (x, y)\n",
    "        img: torch.Size([batch, C, H, W]) image to warp\n",
    "    Returns:\n",
    "        warped: `img` resampled by F.grid_sample at the displaced mesh\n",
    "        grid:   torch.Size([batch, H, W, 2]) dense sampling grid (the mesh upsampled bilinearly)\n",
    "    '''\n",
    "    # The sampling grid is upsampled to the size of the image being warped\n",
    "    out_h, out_w = img.shape[-2:]\n",
    "    # Rows follow grid_h and columns follow grid_w, so non-square meshes work too\n",
    "    grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1, device=net_out.device),\n",
    "                                    torch.linspace(-1,1, grid_w + 1, device=net_out.device),\n",
    "                                    indexing='ij')\n",
    "    # Shape (1, grid_h + 1, grid_w + 1, 2): broadcasts against any batch size,\n",
    "    # including a smaller final batch\n",
    "    src_grid = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0)\n",
    "    new_grid = src_grid + net_out\n",
    "    # F.interpolate expects channels first; move (x, y) back to the last axis afterwards\n",
    "    grid_upscaled = F.interpolate(new_grid.permute(0,3,1,2),size = (out_h,out_w), mode = 'bilinear',align_corners= True)\n",
    "    grid_upscaled = grid_upscaled.permute(0,2,3,1)\n",
    "    warped = F.grid_sample(img, grid_upscaled,align_corners=True)\n",
    "    return warped ,grid_upscaled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def temp_loss(warped0,warped1, flow):\n",
    "    '''\n",
    "    Temporal loss: mean absolute difference between the current warped frame\n",
    "    and the previous warped frame after it has been passed through dense_warp.\n",
    "\n",
    "    Inputs:\n",
    "        warped0: current warped frame\n",
    "        warped1: previous warped frame\n",
    "        flow: forwarded to dense_warp (presumably the optical flow between the two frames - confirm)\n",
    "    '''\n",
    "    aligned_prev = dense_warp(warped1, flow)\n",
    "    loss = F.l1_loss(warped0, aligned_prev)\n",
    "    return loss\n",
    "\n",
    "def shape_loss(net_out):\n",
    "    '''\n",
    "    Shape-preserving regulariser on the predicted mesh; zero for the undeformed mesh.\n",
    "\n",
    "    Inputs:\n",
    "        net_out: torch.Size([batch, grid_h + 1, grid_w + 1, 2]) vertex offsets, (x, y) on the last axis\n",
    "    Returns:\n",
    "        scalar tensor Lintra + 20 * Linter, where\n",
    "        Lintra penalises cells whose vertical edge is not the scaled, 90-degree rotated horizontal edge\n",
    "        Linter penalises second differences of neighbouring vertices along rows and columns\n",
    "    '''\n",
    "    # Rows follow grid_h and columns follow grid_w, so non-square meshes work too\n",
    "    grid_y, grid_x = torch.meshgrid(torch.linspace(-1,1, grid_h + 1, device=net_out.device),\n",
    "                                    torch.linspace(-1,1, grid_w + 1, device=net_out.device),\n",
    "                                    indexing='ij')\n",
    "    # Shape (1, grid_h + 1, grid_w + 1, 2): broadcasts against any batch size\n",
    "    grid_tensor = torch.stack([grid_x,grid_y],dim = -1).unsqueeze(0)\n",
    "    new_grid = grid_tensor + net_out\n",
    "    n_samples = new_grid.shape[0]\n",
    "\n",
    "    #Lintra\n",
    "    # vt1: top-left vertex of each cell, vt0: its right neighbour, vt: its lower neighbour\n",
    "    vt0 = new_grid[:, :-1, 1:, :]\n",
    "    vt0_original = grid_tensor[:, :-1, 1:, :]\n",
    "    vt1 = new_grid[:, :-1, :-1, :]\n",
    "    vt1_original = grid_tensor[:, :-1, :-1, :]\n",
    "    vt = new_grid[:, 1:, :-1, :]\n",
    "    vt_original = grid_tensor[:, 1:, :-1, :]\n",
    "    alpha = vt - vt1\n",
    "    # Ratio of the vertical to the horizontal edge length in the undeformed mesh\n",
    "    s = torch.norm(vt_original - vt1_original, dim=-1) / torch.norm(vt0_original - vt1_original, dim=-1)\n",
    "    vt01 = vt0 - vt1\n",
    "    # Rotate the horizontal edge (x, y) -> (-y, x) so that it points along +y like the\n",
    "    # vertical edge; the opposite sign would leave a residual of 2 * dy on the identity mesh\n",
    "    beta = s[..., None] * torch.stack([-vt01[..., 1], vt01[..., 0]], dim=-1)\n",
    "    norm = torch.norm(alpha - beta, dim=-1, keepdim=True)\n",
    "    # Normalised by the vertex count per sample (kept from the original), times the actual batch size\n",
    "    Lintra = torch.sum(norm) / (((grid_h + 1) * (grid_w + 1)) * n_samples)\n",
    "\n",
    "    # Extract the vertices for computation\n",
    "    vt1_vertical = new_grid[:, :-2, :, :]\n",
    "    vt_vertical = new_grid[:, 1:-1, :, :]\n",
    "    vt0_vertical = new_grid[:, 2:, :, :]\n",
    "\n",
    "    vt1_horizontal = new_grid[:, :, :-2, :]\n",
    "    vt_horizontal = new_grid[:, :, 1:-1, :]\n",
    "    vt0_horizontal = new_grid[:, :, 2:, :]\n",
    "\n",
    "    # Compute the differences\n",
    "    vt_diff_vertical = vt1_vertical - vt_vertical\n",
    "    vt_diff_horizontal = vt1_horizontal - vt_horizontal\n",
    "\n",
    "    # Compute Linter for vertical direction\n",
    "    Linter_vertical = torch.mean(torch.norm(vt_diff_vertical - (vt_vertical - vt0_vertical), dim=-1))\n",
    "\n",
    "    # Compute Linter for horizontal direction\n",
    "    Linter_horizontal = torch.mean(torch.norm(vt_diff_horizontal - (vt_horizontal - vt0_horizontal), dim=-1))\n",
    "\n",
    "    # Combine Linter for both directions\n",
    "    Linter = Linter_vertical + Linter_horizontal\n",
    "\n",
    "    # Compute the shape loss\n",
    "    total = Lintra + 20 * Linter\n",
    "\n",
    "    return total\n",
    "\n",
    "def feature_loss(features, warp_field):\n",
    "    '''\n",
    "    Mean Euclidean distance between stable keypoints and displaced unstable keypoints.\n",
    "\n",
    "    Inputs:\n",
    "        features: (batch, N, 2, 2) matched keypoints, presumably normalised to [-1, 1];\n",
    "                  features[:, :, 0] are the stable points, features[:, :, 1] the unstable\n",
    "                  ones, last axis ordered (x, y)\n",
    "        warp_field: (batch, H, W, 2) field read at each unstable keypoint (H, W >= 256)\n",
    "    Returns:\n",
    "        scalar tensor\n",
    "    '''\n",
    "    # NOTE(review): if warp_field is the grid returned by get_warp it holds absolute\n",
    "    # normalised coordinates, which are added to pixel coordinates below - confirm units\n",
    "    scale = torch.tensor([255, 255], dtype=torch.float, device=features.device)\n",
    "    stable_features = ((features[:, :, 0, :] + 1) / 2) * scale\n",
    "    unstable_features = ((features[:, :, 1, :] + 1) / 2) * scale\n",
    "    \n",
    "    # Clip the features to the range [0, 255]\n",
    "    stable_features = torch.clamp(stable_features, min=0, max=255)\n",
    "    unstable_features = torch.clamp(unstable_features, min=0, max=255)\n",
    "    \n",
    "    rows = unstable_features[:, :, 1].long()\n",
    "    cols = unstable_features[:, :, 0].long()\n",
    "    # Explicit per-sample index: slicing the batch axis with ':' would pair every\n",
    "    # sample's keypoints with every other sample's field and return (B, B, N, 2)\n",
    "    batch_idx = torch.arange(features.shape[0], device=features.device).unsqueeze(1)\n",
    "    sampled = warp_field[batch_idx, rows, cols]  # (batch, N, 2)\n",
    "    warped_unstable_features = unstable_features + sampled\n",
    "    loss = torch.mean(torch.sqrt(torch.sum((stable_features - warped_unstable_features) ** 2,dim = -1)))\n",
    "    \n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_features(image,features,color):\n",
    "    '''Return a copy of `image` with a filled circle of radius 2 drawn at every (x, y) pair in `features`.'''\n",
    "    canvas = image.copy()\n",
    "    for x, y in features:\n",
    "        center = (int(x), int(y))\n",
    "        # thickness -1 fills the circle\n",
    "        cv2.circle(canvas, center, 2, color, -1)\n",
    "    return canvas"
   ]
  },
  {
   "attachments": {
    "image.png": {
     "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZoAAABRCAYAAADrXmCgAAAgAElEQVR4nO3dZ2BUxdrA8f8m2d0ku5tGCiSBhBJCIPQmSBEEBPReBUV9RQHFchWwt2tFEUQREeWKBa96EQsioCAoSIdESuiBhPReN2V7Pef9QBEUSEiBBOf3cXN2zuzm7HnOzDwzo5BlWUYQhLpzF7H6jR8JvqcDyf/Opve8jmy/5xVcz/yHPuU/4bhuHFUfPELRuKVM7ReBn+pKV1gQLi+PK10BQWj2yvayy+VN2s9p9PtkAvGeZeR7dcGveC0nejzCwBATeYf70trPG6Xnla6sIFx+ItAIQj05i9M5ziHofRsdfXRUHtxJUicHB5yDmNDeB+/Uzfw2sDdxkT6oRaAR/oZEoBGEenFyNGEj3SPHMKiLP94qPft27MPjUDR3/rMrOh0c2bWBXr0708rHW/zghL8lcd0LVyVZcmAx6NGb7LikRhyGdB8ncW0PevbpTrhWhaczlxOJKm6aOYleIX6o3MdI+Kk3fVr743RLNGZVBKGpEoFGuLpITqzmakrTdrHi/Sd5bVUSOdWuxjtdcQYF1/ajaxsN3p4gpSTys3Is3WN1+KiAshwyu7YjxL6TNftKMNulRquLIDRVItAIVxdLPnu3f8en77zBf77ZRVGFtVFPZ8hJx69/PEEaHzwBg8FM6JSBxAX4ogIwGijKXMXLP7gZ1CMQjbf4yQl/PwqR3ixclVL+xx2PfYg0djZzJw2hfaDyStdIEP62xOOVcHXyABRXuhKCIIAINIIgCEIjE4FGEARBaFQi0AiCIAiNSgQaQRAEoVF5XekKCELT5cJSqcdgdeG+HLmZPgGEBvii9BBZDMLVRQQaQbigUrYvnMH8Hw+RU+U+Z1a/UhdEgEaNV11jguzG6XJit5kxmx0nyx49kzVv30GsX1NeqkbGbihEb9IRHO6HWIi6kcluHFYTRpMTlX8QOnU9rgzZhd1sorrajBNAqcZXF0CgT+OHARFoBOGCwhn91CuUuV7i0zXHKTK4OT3rLGrco7w86Qa6hGrwqksLxFxCZm4aB3atYcWq3WQVVGCvsmCVZJryxDZZkkhd8X88/uX/sWjbw3S+0hW62jn1ZG77hs+XZ9HpsVnc20NXx4IknMY89q9exqcf/0KG5MTSMpY+dz3JnJviGz3YiAmbwtXpxP+449EPkcY0wITN8kQWvT6LrzedoMzsPhUIfOnz6Dxev+c62gf74lmP3i5z+gpmPvcW6w6NZOGuFxgapqVpTC+VcZorqXR6ExTgK55KmzHZXUHqoaOkpGsZcXsvtNVpbPhuIc+tq2bYgzN5Z2z7Rm1FN90WuiDUg9vtRpZl3G43Un2fpYIHMP2V55gwsAMtfDxPzQO1sO/9Oby7cjc5VfZ6LZap6XAbT08bRYjGRIWxKS28aeLgl08x5c115MuAo5qCggIKC/VY3Fe6blc7GbfTQnVZAQUlZeit9Vuvz1luQ6fU0Of2XmgB/GPoN2w0U/vbsZeWYWqQOl+YCDQNRkZy2bFYrdiczX/hRFlyI0lSk+7GOS/JhqGslOPJaZjNVgoyUkjPyaHc7KjfgH7wUJ5483kmDGxHCx+PU8GmgN/efJn31h0h1+CsV4AI6zOC64MVuG0yTaaPwX6Y1Z9tIy+7kAqrjDvlf9w5/h+MHjWTzforXbmrnYuK7F188eI/uWXqI8zcWlCv0lRh4UR07U3kmVdkVKoAAgI74OOtbvRAUP/W8KnBqupKE46zX/dUotb6E6Rt/A9xxckSTlsVhVkHSCqQaNmuLwPbB5z5s+QwYzAaMNtqEYA81Gj8/PDTqBrue5MdmKsMGC12zn4QVWr88ddqUJ9nRNtt0lNud6PwDiDQ1xtVffqGLqfq/Sz79/v8fDwfE354H1zO3OnLuf75j3hoRDxhPvUoO/g6Hp3lRH5xLst/z6bSKiFTwLp5z6D0mM8TN3UjQudVt5Vv/
Hoy9ol8zL5KPC4WaBwGisqd+IUE4atUnHUuCUtlCVWeQbTUqWmIxDXHgV9YXa1l5JQJ9PJVQLfpbPrWjxmLAxkYWv/yG4Qs4bRUUlplq8XBHnh6+RIY6o+6yV/OSkJirmfSc/OJ2bIcZ5eovx4iubCaTZgcnvgF6vgjT0DG7bBiNFuQvAMIOt/4i2ylpDQXh8uL0YN6nmzlNKL6BxqnnsxtX/POq0tJMlZhdHig1vjh37YbA+56jJl39KCuw1fNg4TTXMDBLWvZW+FP9+tvYWCk7zlHVKWt479LPmb5xkKMZjdKnQ6dVs3Zmy3Ksgub2YjJvz/3PPkYD4/rQWBDVdGSwa+LFrN09RbSDGZsCm90Oh1d73iC6XePo1+4+i9v8dKq0CeuZHtFAN36DqZXVAt8lM3gkSFwIA8vGcjDjVS8Mnwkj71YReXri1j3ez4mh4Rclc2Pb71BaOBspgyJJdTXow7BJoDeN91T82Gpn3HT/yXz6A8L+b8YDaoz/5IKNrwxkhdbf8yeR65Bo6rvVp4ukpM2Q/Q93Dm01cmXZJmyA7/hNWwhQfUsvaHIDiO5m+cx+Y2NmCorsbh98A8JwOecS/Vkb4PZ6EITdQ/zVz5Nf98LldiEyEb0+Tby9w5n1P3n+bsxi13ffsH7u1ry/MIZDDxzw3BQdGwNn377IxVDZvPB2LZ/Lhi3sYRiRwDKrg8wIpJGV/9Aowql05jpvNOpG98ueI7397dkyAP/5uXbBxJen6fHZsLtKOPwtq/4LcuT7mPu5trIv17BQV0m8PTcnvRo+R4ffF1Gtycf59HJAwg5uxxTLjtXfMBn6f6ER0c0XJAB0MQx/uV36TvkXeYsWEai92genvEv7uzXFr8LjTp7BNDlhrvQbP2Az9YWUH7deEZ1DcenubRsGpGywwRmvuLE/uRifksuxOyWkKsO8OnzL6KYN5d7B3Ug2KcuwaYJcZ8gaaeBa2ZM4JpTl7Qs28lJOUrHm5tKmAGF2p/2/5jF+r438Omj0/mubDyzV8xiRIuzDpIcmHIS+ea/b7LKtzuxzSHIALLNhEWhx9C3PS1qPrzWJJeJ3GN5GPO9GHZH6wYs+cIa6BH15NOTwsNFm6hejB/59wgySFby965hxa5ynFFjGBhzkbabAiRAERtJWMfIvzwRemrbcO11Y7hj1ECiW/g3QmU9USCjULRg8IChDI2/SJA5w5vogRMYzFG2rPuJxMwqXE1l/OAKU3a4i9fn/YthnVqh8TwVVKoO8Mnbb7F8dw5V9mY4vnUWKX8viaX38MDI6FOvyMj2JNZ/P5kRXa5o1c5PBtnDG3X/fsT9+a7soUIb2ZN/3jWVAZ1jCDhvAU2NjL1aT1V5AZqeXRqsV0h2Oyg/cZAcfR4hY0ZxGRozQEMFGtmMsSKf7IIWBHUZQJfwBim1yXOXHmT7uk1k6APpEdf+ohewVJZHSlkBztCWtA8JPtNtJrvsOBwOnBIovEJp2yqC1sGNMA1OrqYgu4gKVxRt20USWttYpoqm35CeyPt/46cN+8g0iHSj03xjJ/Lv1yczNCbkj5Ze2mYWzF7A+qN5GJ3NNdhIFO3diTzj9jOtGWSwH1zPmoeG08VhwOC4aAGXl8uOJfcIe1JUBHXrSMTp12UJyWnG7AQUCjS+flzTOfoiBdVAcuMwV1JaXoWtsX8Gsp2K0gry09TER3ljs9qpb4qR7HZgzN1NSnEOju4T6ekHyC6cTjtWR+MmMDVMoLFVo89I45jen6Bu7fl7xBkH+Sf2cfh4HrqwtkSHXaw9LlFamE5RkYXIllG0CfU587o5bQ+Hjhwg0wKekfF0jomlfWM07c05pB8qQx/cjvCoUC6lzRTQoQ89Wyo4tHs3h7LKcDbPu2ejCO39AC/OvIOe7VugOj0Cn7aGV56fz8/JhZicTWUCpoSlsoiiwkIKCwtPJjJITiyGckqKCyks0WM68481k5oVy4Mj/rgpy7g4nPgd3X2hYP9/+Tr9ynyK83E7b
ZSkJnFCpWFUXLszr8uOavSHl7M2DfDSoY26gZEx9TiPuYgDK1/nrkfeZkfJpb5bRnLZMJSf/P5LysoxOdy4nVYMlaUUlxRTWmHAerrLwF5BeVoqhw4oMRUkk7D9EOa6Vx3JZaXixG/8un0be539iPc4WY/CrCPs25/AthOWepReswaZg+UwVpCbk0KFLpDx0bVvjEkuKxZjJQYr4KnEW+NHQHPJUnMUkHkgkxNp4cSOir54C0GuoCgrh+LKFkRGRhPiByDhthdyOLmMKm1rons3bnXt+ekc0ZcT1DaKdiEtLu079gunTTt/dEv3kHx0JAPjWhLx1/yBs5xM0a3bXGAFHs1sra/Qa55g3qsOHpv5PQczKnBKMqSt4fV3Q9A+9yDXdwrF94qPbVWxbeFdvPVzGQXlbsZ/vIkXuley84dFfLZuD2nGaO6d9SHTBoehVqjpdPu9hJ+diiRXUaoPI2vldO5+YDFbJl+xD/InMg57BccO70WluZUubU/2B8uSC4O+hMQd+bScUo/SJRcOuwNUPngho3DasFfZMTo4mfHmsOFEhY+6pmxDF5V5u1k+71k+32nEK6IHk2bO40bNEdb98CUrE3JRRAzn/iee4ub4AJQuJ1ZXLom5BylZombarKn16j4z5iWy9uvZfLYJYD0/nvlLNP1umMLjrzZu3lkDBBo3hqp8MlJzCAicTGxk7QZnZLeD0hMbWf75bL7fLmHVRtLj1mm8+eAIQupSK8mJ0+XC5emDT30TbmpBrigktTSf5JY6+kT44Xexg616CtL1FHjpaBcIjqIiilxGCg/+yPJUJzFDejKiUWvrpjD3IBUVnrQd2IZQ/0udd+5HUCs1vn5HSM3NpqyqNxFhF/knOUwUlegxW+vSv6IjLKIF2oZM774MQq99joUvy8yYuZxD2VW4JBnn9v/yrELNwpn3MjgqGJ8rGkCDGDNzM4OufYr+D5cSq97H6v2xjHvgQ268ayMv3fIwi594l97b5jJUqyI86tyBDoVHCP+YvYt/XKHaX5DsxGnOIuWICs9ObdFYiyiySDhNhST99h3LT/ThzboOysgOjGUp7Ni0D3eXmxge7YU2KAo/L4lArRNjSRYH9yaQ4tmfu8fGcfE7n5IWbYfy4BtfEfvVu8z51QdvWypHiWfS84u4PnEpb774LUv/E03se1Ppqo2i/x2z2X5HHev+J/5thzN51nAmz2qY8i5VA8yjMVJdnEN6RhCB47rSrlb/VAlrQRKbdx8id9T37HjJwZbVn7AkaRsHikYw6pITIWRcZQc5mJVHWfh4xra59I9xqSyGCkzVZaANQ+Wn5WIP+I6SXFJKCijMzOfnN4+wQQEg4bAa0Q26kyHBQX9anFDGaXcgKzxRquo4L+NsUhFZx8ooJZoubcMI/HPXnNuK1e2Fl5eS82cwawkK9sVX40FCfiH51Qa6hwVduF4lO1n81hdsTcqrQ2WvY8aChxh7TVSj5/Y3tNAhT/HWU0aefmsNR/NNJ4PNtsU8OzeQT+bcRa8Wmiu8jIuZ3xN+RvLSsTmzDZ9O7ngy+AUM5vpR8N2nezieB0PjGu6Mcl0m/SoUKBSK2l33LgfmzCMkpJuxVizhifFLTp3XidVTRch9d/4xZnNWnZw2KwpfzcWX+nGYqUrdwrIvllPikUjajClcrw6hm5cNZfUuVn61gI9+ddLhOi3X3xBHuxofcN1UG/JJTt5EhaM/hfa2TBwaiRKJoKhu9Or2CatKUsgtg66XJxnssqn3dS9bjJTnZZOp9KdvfHtqN4/LQF7yAXI2FHHtnDbgD8Mmv82wujbHJTOZx/exe1827W4fX8dCLo3JWICxOgeUvfFUnjsn5k+Vo6zwOKUlDobeM4tHp99IrAZw6Dm0fhGJ1WGEB57b7yY7Kzi8JwVJG0p8z5ganpRqJlcVkJ1VgTm8L60jw/40PuPCmvobm81tiO7QnS4XyKv20YWjVPvhqjJgttpwwYV/pK3HMGvRmHrWumZ6v
R6Ho36j0n5+fvj6+qJQNERrQ0n0jTN502ri8fkbSCu2IXn54uusxmqVTs74v5KNGmcyieslVH0e4sV7uuF9poVVTWUR4NGGsAZMeJRlF1X5mZTXZi7lGQo8vHwJiYrArxY9E26nhcLjeygPjmX60rVM7QjgxqpP5bdv5pLQts2516ksYSnPZ++GLUTePZkOFytcHUibwY/w369HkLjuO775fg4HIzvjKxXz/WcbqfAdxasfj2dI51b41qYXRTZjKMni+LEIYsdOZNKotqfq5sRlN2Gs9EalCiNAU4uympl6Bxq7UU9+5nHsflH0bP/nZ4e/kt0ObGY9ZZUGzA4rVWVFlPl5o/UPxMfTjcNiwmCw4MQDpVqDX4D25KQ02Y3dYsRosJ5Z4tpH64e/0o0lbz/bt27jp/RQplxXRoVWjVLpwmSyo/T2Revnh5fLjMloxCqp8NHq8PdVITstmKxWJA81XpIdo8MTjU6Hn9oDl82EyWjC6gJPtS9anR++yj/uEgqFJwpFLTp3pFJy04vIc0XSJS6KsNMXkaoF7WNHItm90J3+cctOrAYjpYdW8XuWTFDHKGLNDtT17EYyF6ZzvEJPeGwEbYP/aHK6bVVUm9JZvzgbqasvHTuasbt9UZ9nPEGt0uHlqQaLDZfTWe8MmIbw+qsz2Z+UVK8yHn/qKUaPGY1G01C/biXtbnqS+/fvY87aSqSuE5nz2oMMiNChvMLDNK7jCWzV+zDyjVvp6PnHPB/ZkcnB38GzdXdiWjVUJWXcjmK2fDiFBdsu5X1KfAOG8czSmYwIrulYN3ZbKUcPpKDWPUiPM/MSPfHRhjP42uFU6U4/+p6ctGnIzyTlyCreqB7C16XV2IL98b7Yj0vhhTqoIwPGP06HPofZsHwBX1h13DpiDs/1jCRQo8G7ll31srma8rRj5Gjb8M+xQ/9ILZasGMqLyMzUoRvajsjGmqYkn31/BbxU+Oj8CfBu8tsEODBU5pKeWkFA4P8Re544I7tduF1OJKUPKg9w6lPZ9M07zP7iIBaLE+X0W9nYfziTn3mNYdpMdq9axpIlG0h3aYnuNZb7XniIoVFq5Ko0Er79nMWfbafI04UlIp4hk5/m1Z4WNn73EZ+sPILVDu8+UsmRYf2Jikrlv4s3ET5gHP96+RXiin7ki4XzWFkexw33PcULt3SiYvfnfPbLLqythtHOkMCnR3VMePg5nu/nzfHty1j25XJ2pTvw6zaCCQ88zp29W565qNS+LVBrwsAiIUsnV/Q9309UriokI6OAQp9IhkeGnpMC7fIKwl/lScDpJos5na2fLGbxF1vJ91ThrTlA0bT7mXR3L2r8zV2QndysI5TrA4hu3Y7QwD9+VSVJH7Fg0VI2HvNCtduXTZWPM/2eW+hznpUCZFkCZPD0wKPGp//Lkwzw3gfv1/Ec5/LwaLjRILetmpSf3+X9DSbcXW7n9denMSDCD9WVzgVAIvPIZorUE3jlGv+zlqiRqN6xmlWVvnScMZYuf6qnw2gAXV32nVHgpY5k/NwEGq2PQXJg1x9m33YVPpO60P6spossK1Eo2hDb6vQrDqoLdvDhhOf4vtKOssUWHtp+A48vfoahF7qxyxIuWxXFWQdJ2LiWXw5k4N1zMNeE5VGy5XkeXtOLcWNuYVivGMJbaGp8kLCaK8jOzsUj5J9c2/n0j17GVVVC1oHf2e8fyY2j+3H2YjOS24XL4cDDp76rZ0s4TbkkrfqKjxevJ0NyYg3vTL+JTzLnpi6NHmzqV7rLQEV2GsfTdQRMjaL1n0uT3ViK00hPz0DqcRM9/UEV2pWbHltITMevWfm/o3Sa8yHj2oJkymDv6k3syRjAG4kz8cv4hW8/+ICv54fT5t0R2Lf+wqYtJsZ+tpEpUbms+fI9PvvxC3b1nMeEGbNpFfQhP+z2YOQrb50co7Fk0tlDxer9AGqi+k1k2vMy6lU/U4Gd/IMrWfHxZ6wt8SPkGhdx7Xpxf6syZC8LKQk/kZArc
dP8BOZ67eerTxey+ptPadnyBUa3ORlpfHyD8NMG43GiCou+EjOtzzOmIGMsziKnOJ/QsD7EhZ47wBrQodO5c2+0cYx58D4M5jiMkb24+f7+56wecLrMk/dWBbXq7XEXknlYT5m6Df3ahnB2HkD4tc8xzVxJq/JYrhk8kYGtLzzSZLNV4XTZ8GoVhJ9Wc9ExqcuVDKA41ZffVEgOI2nrXmXGm1up9h/IE8/PYHi0P/XZq6rByJWkHTqB500P0Mn7j/EPtzWTld/+hFf7+3n+1thzH5YcBrZ/8RmtZjxBk5yj6bRjTE1ih9qbQd3OnYip8Nbg32M4/c68oiYwajAPfT0X5QO76PXDa+euHnA+TgNFSd/ywotfUhJ8HZNfeo+BFVv5zyYdtz36EMXLZjP31af5acD9zHnrTjpddMDHidVYQn5lFX6DY2hz6lhZslKUe5idv2fToec07hx0Vtau5KIyO5Xk5BNE/XMc51ntrNZkdxUZJ/IoUY9lUeLLaKtP8Ot37/P80nd4XTOTd8acvU2AjOR2I0ng4eXVIOvm1T3QyBKOymIyjh8kJTCA8V1iz0q/k5HdLuxVmRxITiKpqgsPXrTvV6aq8BhF9n0Ej7sFZVExVs8Q2rSP4fCeXSTnjiTOR0vLiDAiAlRYPFTotJH0wFjHynsT3X8S054D1QcrKPaO5cb7RhAOyCW7+XynityQgdwsF1PsDKRdTAxBlaUcyC7khjatUQCqwDAiQsJpaTdhNzuwwV8Cjew2kZNxnJw8HyJGdCI8pOY2tttQSl4EqNsF/TVwyU6sVXr0VVYUmkCCAv0uvv6Y7MaRl0JyTiFS9CjaR4ScW6ZURVGBGx+fAPy1FwsdMmZDEQ6riZiQYIK1NUz0KdnJh3O/YGtS7sWPO69hPPpe80sGkJ1mcrYv4rW3tlOtG8Dj82dxa3yL2vXdXw6W/ST8KBPzUgeCFApAxmWt5uiK2Xy0vzv3fT6dAZrTdxQnxiI9puQPeMR4HduLKrGGBl6WbM7ak3FYqjiStBtPZTTXxtZiWoXTgb00n0L/toyrzZouSl8COw5l6rNt8e41gmv89RxZl8tBp8wNdOaGBz+g47W/c9QVf05r6rzcRqrzU8k44Eu7sZHoANntpDr/CDvW/MgBj2FMfeRm4tQAEi67FVPuUXYc2skPlsEsqDBhD9DW+aHl9DYBve/ofWqbgI70HzaaqeX/43hJGSban8mclRzVFGWnk12upGWnONoFqeo9tFiHQCPjdtow63M4nriOlT/uxEs7AJ3bTlFR0clD3A6sZen8/ttKNpxwc91jt9YwoO3AbDJxaGMCaz4ez5KzLujg7tdhNSvpOHISkQOqqcjbzopln/Ph8sO07DWavpf+Ac4VqUPTMejMl2wzG7DsXcPaDUvYrj6rIlH96WQxYgV8AXRhtO8YRSy7qM4so9IBZyb0y06sRhMV6dv5bf0e9hu0DPUHh9GI1VOLzwX3/5UpyjmC1seL9tGRf/3ODMdZtehdPlq2E2efu3js8YcY3zviPN0yEg6LAWNZNolrVpCYfAx1vxG4qywYzN5/tBaqU0k+1h7/AR0Jv+jiatWU5NswmsPoGNmGYL8a0hNaj+GN/zR+MkBTITvNlO77iJde+p5U3y5Mmf0y47o0oSADOI8lsk4B3U3FZOQHEOBlJWf7B7z8oYNx785nWlfvs24mqfxv/L18UmREF7aV238eyPOr5jG2SazYLCO77VQXFpKT/ANfrqpAoRqGj6uEUnMQoZoL3/FddjOFhQfJuWsGtZq3qVChDe3CdTeebM+5jS6sphLcngpMZvAKCSKm79halSVbDJRmH2Wv7GR0VRFZ+TKK6nQS1n7H6tT23PX8NG6OPl13C4UH1vDhI2/zi8mFKnAzT+8dz7RZU+lbx1RtVVg4EWFnT6WXUakD8P/LNgES+rRNfLPgRT4+0ZWb/vUcs+7qU++HvjoEmpP7JHz77st8dwCgJdizWPbMeJb95dhQ4
gbeSZ/2NeVNKfBQeBDZdxjPvrGQ2zucrpaE2y3hdkkYMnfx3eeL+FEfxz0z3uTT2BWs/bno0qtfU00UHig6Dee262/hhZHRZ7LJJJcLSZbPyi4LokP3PnS/JoGtFUfJLh5FzKluNSzZ7PryMz5bto08wBvYveRVCjMnct/9UxnZ7gLfh2ygrDiYQHUYUS3Pc4zNTLXLC/wDcew/xtE9aQzoFkHUXzrQjZzYsJBPlv7C/gKAlpC4nNmJ2Ux87gHuHBdPIOCuLCK7T0t6xYVffBFPdzGFR41YNYPoFh9Ny6t7Oe5L47ahT/ov/37ta1Kcbbht9mwm9wlHd6VH/s/hJuXgVmytopD3zeXBpSbkwCC6DLmTOWvuok/wnxcAjWfazv9Q0eEzuq75mFuaRIA5zYW5fA+fTHiOnwC8/fHmN96ZdpDwB99h+YTYC7xPxmk3YixM5x+j67Y8gIfSh+B23Rk8SEH4RSfO/fXcZqOerJxjOFXeHF/xHJPec6KJiuHa0ROZ8/BwYs4pT0ub/mN5YL6alp/m0+/DRxnY0Au0yVZKSnJxujy44ZxtAjxQa8IIjYwj2h5JeEQw3g1wujoEGiUhMSOYsXgEM+p6VvnkALpLlpEkCVCh9Q9EUpnYeegYN7bpikalQHZXkJdfTEmWmeKD69mgH8HLbz7MNQFlHEh2Yj9f2ZILl0tCQnEyaMgybkkGyYXdZsNsdmA1GTFYnef98N5+QQSW2EjLOUT6oAg6+ihRyC6Ks9IwOiVad+7C6Y4jnzY9GDToGo5szudASiaDImNOLk+uiWHEjLmMuMQvSK7O4MQxC8ZWOjReDhwuJaqzWz9hA3jwhV5MfuIEvy7cSIWHzPnTv/yJv+VV3r/l1YuczU328b20cHYkROWJ0+nEU6k877iIKy+No2Y7rUcOoGeH1lyF2Zd147ZRfWI1b876iiMlkYyf+w6PDIrE/8qP/J9LziN5TyGqMYv5z1ND0XrW3P8i5R/nhHYbT7cAAAvKSURBVDaW8U0qyAAo0YYN4dnfE3n2Ut4m2zDrM0n4uT8DJjoxW0Dje2kTlxXeLWg35CFeH3JJbwPsmKvzyC7REn/f2yy5o2uNW3XLNhPGygoq/SOJaPBVQE9vE+CPsttDjPxTr6Nf9CAmzRzEpAY842UfppTdDqzlBeTlZXPCaCQzPQt9hQmfVvH0C49F8cOHLNmaTnFRMQUpGVQU5KNqF4BSZcRhPUx6ZjHFmSkczUlhn92KsboKi0uBp9ILhdtKVU4yx1OPkuLyJzA4HF+7ieyMDIozfufXTbv4ae1Rtq5cxqc/76Ok2oDZYsdiNmK0Ok9OLGsRx4DBIWiMP7F0xXayioopzj1IWnU1ZV5tOWd0wjOEroPHckOElsqdiRwqttdrF0dLURapLR1UhelJPnCUI/l/nYCg8AJnWSEl7cIJ6hNDeJ0fNwpI3e6HsjgQ46EdHM1Ip/p8h7lNHDuwg+LWPRg+qi8xoWLneADcdgzpPzLrmflsLQxj5Euv8fCQaAKaxMj/ueTKZBK3ezKoUye8a5OSD5QmbeDI9FFNMgmgThwWDAWH+K1XNNrKJFb8kn35zu00U52XTV6aHz2jo2oMMgAOQzn5jhzK/tGrXkkA5yO5TOQez8eY58Wwoc1qm4Dac+pT2bRsPq99kUCWPoMfX32c5+es4qgjnO43TuHeW9ty6JVJ3DZuHNMXriJFN4iebaLpPnAEA5UH+Wj6OG57/SM2VAQSm5HE8sXzWZOnI7JrD4LZy4eff8W6qjC6+QUR038AcZ0q+OaFe7jtzZ/JCO3NQw+M4M6HHmRC6zRWLVnGtm0FJK1bwZebUjm5rJw3bQdOZOq4a/Hc8BKTxo3jtme+Ym+Zjl4d/zoI7tmiJzffN56+UXoSN+4g12iv81L6Ci8f/NJ3snr57+TY/egc/afuM8lOVd7vrN1bRqt2vbm+R2StLtrz88TbX8+eFXP5rVTGOyzur91nkgND5lZ+TQmj76BbG
d7pPONGf0eyE3Pmeua+NJ+thV70f/oVnhgdf/6dDC+5aDMGqwN3Q+3nLDupTtrFb1JX+ncOrmUGkZOs40lMHBDfMHVoEjzwcrtwbf2Gd17JYOD4eqyueUlkXMYK8lIzOKHpTNfo2vS5SVhNFXgYixkeV4/Vps9XG7cDfdohcstzCBlzw2XbJgBZaBCuqhPyrjX/k79Zt0vONkuNc47yLDk7L1cuNrsapfyzSU6rbClOlL/78gt59a4Tst7e6KdsHiSHbCrYIi+Yeo3ct2c/+V9f7JFLTc6GKlw2JL4vP/tNglxgccjnvYqOfyKPGf60/H26WXacc4BeXv/acHnYp3tks8Mty7Isu60VcnHOdvn9ezrKsW2nyV8dzZPLzM7zl3s2xy75xdiH5R9LHLLB6GiQT/Z3JLmdsrU6Tz685RP56bG95RH/mCf/mlMuG2zui7/RXSmf2PK9/PakhfKOcodssVzg+jJkyjuWzpWnPPZfeW/V2X+wy4VHf5LfmTNNfmlzzh/1cdnl6qwd8vZN/5N/zT3zouxw2GSLvXHuWacpZLmhHp0EcFBtsiPLagJ0jbCnzGVky93J1hxo2b4bXcL96tFyuorILixle/nkxadYts9F+0mzefu+IbT2VzfIyjKys4RVr97Igsj5rJ5yLS1869dCqtj8bybM3IHd8cfmKT1f+o4FY1vjdbGmzYE5dFgUzI9vxLPxFz8ev/dqatlcPq6qLBJXvMm/lxw585ouOp4bn5zL9H4Xya82ZLJ980oWpYQz7cYOVBaFcsuo+rVsJJeNqvStbN6TRE7L25nY9VRGj72Y7JJKqnz7Mrpr400oEIFGaMYk7GYjFpsNpwtUukB0PioUTjNGswW7rEan0+Ct9EQhu7CZDZhsEiqNP1pv5SVORHNjLz/I0rlPs2RrJa3ueIsFD1xHVEBDBBkZl92EPmER9z/+DTGv/8RrY6LQXamkgqMfcO39y1AwkY9/n3H1jNM0F6Ycdq/9mCffSyF+wBReWPDPeo/TVGdtYtWSF/l005//Ek3/G+7jydcad7dNEWiEZsxA4tJ3+PaX9Ww/6GD4y5/z6KjOKDN/4JNPPmZ9WR9mvPoUt8SHojRlsuWrt1m0Kp/4f73FY2O70UpT2yFKN/bqE6xZMIN3VuSi+8eLzHpkDHFBvhdZTLUWZDdOpwOn1UjGziW8NWsVx40deHr1Uu7pGlDrNbQEoakTKURCM+bHgHteo3NHf6bOqiYu3Bv9kS242oxi4oQUDs7eSHL63QxpUUFOnppet97LyF9fYeXRNO4e2pFWmtpsZerGYc5g7XuPMn9FJuWmKDq6jrP+2xw21rfBYS0hNXk/R/aXYnGcWk5f3ZXwUC8acPk1QbjiRKARmjkbJ5J/x957EIH6Qxi7XU/3iCCK9qbhIJ6oIDtpud507xJJoPE4WU4XUcFBaFS1GXWScVkz2Lz4Nd5dmUuZSQaySfg+m4TG+jiDY4jSenHBxSMEoRkSgUZo3txZHE2wEtGqgkKv6xncIoAArzx2HjFij+uBymSjbd94/HQ+VBzYT6ZdR9+2kWi9axNonGSs+ZCP1mfg8NQQGNT4U1UVXTvQSunVrHYXFYSaiEAjNGtSdjJJxZChCOC2ri3xC1Ahl6Sy/6AJbXsL2tadCVHrUMnVpB4+gDmsLz1jg9DVaqKrC138VGa+fTfumg9uEKrwOPzVYnBGuLqIQCM0YzIluckU2iV6jBhKXEgQ3oAh+wiHLBZKNJ2IaaVBqwKsKRzaZEbuEU9bf00t12/yJbxzV8JrPlAQhIsQLXSh+ZKNZB0/ijW0NwO7tiJI4wG4yc9IwuqM4LbBfWgTqEUBuLOOsMMq0zOuHQE+Yn0DQbicRKARmi97Kke3mNF06kGHFgEnl8dxZ3JkbzW2zkPpExuM/6n9PbJS92Bw9aJbOy2SS8LlurJVF4S/ExFohGbLnZ3MLqsv8Z1jCDm1R45UlMLeNDsde/YlKsj/VN9wGTkphdi6h
xNoSifPaMIiZo8JwmUjAo3QbBUVZOIZ0p1uHVugO9UbZirJocq3Oz1iI/A/s/OYhCQHEu4upzKwJ11a+eMn1tQRhMtGrAwgCLUk2Q1UGiW8/fzwVf15s7BLLg2bsQKj5Eugzufia48JQjMnWjSCUEuGpI+Zfu/rfH+4DFN9xngkG4bydFa9OoZR//6KVIP1/PvXCcJVQqQ3C0ItBQx8hm/WNEBBFTt4f8o8NuabMPcRLRnh6idaNIJQE8mOUV9OWWkpRpsbqb6dzcEjeWntBlbOu5Fg/+a9nYQg1IYINIJQk8oEPn74bm4eMJx5m7OodFzpCglC8yK6zgShJi2G8fTXQXjfmkDPjoFo1SdflhxmjDYHUq2aOJ54a3V4e9U3iUAQmh8RaAShFqTURDbeEs/oIA2n4gyGw18xe+VeSo21aeJEcvOMZxnTPgAfsZSZ8DcjAo0g1EiirKCAnqGDz9leIKDPQ8zr89AVrJcgNA8i0AhCTaQy9u71oPvNgWjP2vby0rrOvPDRalGLrjPhb0gEGkGogVSyhwQpnrsCleB2I3l64qEAw+GlJ7vODLXpOmvPbU88wci2/nifnYLjOJnFJmZNC1czEWgEoQYVxxJI0vTm5uLt/HCiD//sG0mQ1pOAPv9iXp9/1aFEF5aKaiqrbbi3pZFWWEKkbxsC1J6itSNclUR6syDUoLI6B92vb/P4F5V07xyEv7a+o/npfDd5Ig+9tQcP9SbmTbmDT/ZUYHE2SHUFockRa50JgiAIjUq0aARBEIRGJQKNIAiC0KhEoBEEQRAalQg0giAIQqMSgUYQBEFoVCLQCIIgCI1KBBpBEAShUYlAIwiCIDQqEWgEQRCERiUCjSAIgtCoRKARBEEQGpUINIIgCEKjEoFGEARBaFQi0AiCIAiNSgQaQRAEoVGJHTYbnITTYsHiApWPLz7KJhLLZRc2swUnStQ+3qg8xV6OgiBcHk3kLng1KefgN+/z9NPzWXq47EpX5g/GZNbMfZpXFqwgqch+pWsjCMLfyP8DJPZxUCjVta8AAAAASUVORK5CYII="
    }
   },
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image.png](attachment:image.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class IterDataset(data.IterableDataset):\n",
    "    '''Adapter that exposes a zero-argument callable returning an iterable as a torch IterableDataset.'''\n",
    "    def __init__(self, data_generator):\n",
    "        super().__init__()\n",
    "        # Called afresh on every __iter__, so each pass starts a new stream\n",
    "        self.data_generator = data_generator\n",
    "\n",
    "    def __iter__(self):\n",
    "        samples = self.data_generator()\n",
    "        return iter(samples)\n",
    "# Datagen is project-local (datagen.py); IterDataset calls the instance to obtain the sample stream\n",
    "generator = Datagen(shape = (H,W),txt_path = './trainlist.txt')\n",
    "iter_dataset = IterDataset(generator)\n",
    "# Single-process loader (no num_workers given) that collates batch_size samples per step\n",
    "data_loader = data.DataLoader(iter_dataset, batch_size=batch_size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.tensorboard import SummaryWriter\n",
    "# SummaryWriter logs to 'runs' by default; use a run-specific sub-directory for this model instead\n",
    "writer = SummaryWriter('runs/vgg_16x16/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 2, Batch:1193, loss:8.925530743091665        pixel_loss:4.0964155197143555, feature_loss:3.904881715774536 ,temp:4.310582160949707, shape:0.2874298393726349353"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32mc:\\Users\\VINY\\Desktop\\Stabnet multigrid2-DeepStab Modded - Future Frames\\Stabnet_vgg19_16x16.ipynb Cell 12\u001b[0m line \u001b[0;36m6\n\u001b[0;32m      <a href='vscode-notebook-cell:/c%3A/Users/VINY/Desktop/Stabnet%20multigrid2-DeepStab%20Modded%20-%20Future%20Frames/Stabnet_vgg19_16x16.ipynb#X14sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m \u001b[39mfor\u001b[39;00m epoch \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(starting_epoch,EPOCHS):  \n\u001b[0;32m      <a href='vscode-notebook-cell:/c%3A/Users/VINY/Desktop/Stabnet%20multigrid2-DeepStab%20Modded%20-%20Future%20Frames/Stabnet_vgg19_16x16.ipynb#X14sZmlsZQ%3D%3D?line=3'>4</a>\u001b[0m     \u001b[39m# Generate the data for each iteration\u001b[39;00m\n\u001b[0;32m      <a href='vscode-notebook-cell:/c%3A/Users/VINY/Desktop/Stabnet%20multigrid2-DeepStab%20Modded%20-%20Future%20Frames/Stabnet_vgg19_16x16.ipynb#X14sZmlsZQ%3D%3D?line=4'>5</a>\u001b[0m     running_loss \u001b[39m=\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m----> <a href='vscode-notebook-cell:/c%3A/Users/VINY/Desktop/Stabnet%20multigrid2-DeepStab%20Modded%20-%20Future%20Frames/Stabnet_vgg19_16x16.ipynb#X14sZmlsZQ%3D%3D?line=5'>6</a>\u001b[0m     \u001b[39mfor\u001b[39;00m idx,batch \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(data_loader):\n\u001b[0;32m      <a href='vscode-notebook-cell:/c%3A/Users/VINY/Desktop/Stabnet%20multigrid2-DeepStab%20Modded%20-%20Future%20Frames/Stabnet_vgg19_16x16.ipynb#X14sZmlsZQ%3D%3D?line=6'>7</a>\u001b[0m         start \u001b[39m=\u001b[39m time()\n\u001b[0;32m      <a href='vscode-notebook-cell:/c%3A/Users/VINY/Desktop/Stabnet%20multigrid2-DeepStab%20Modded%20-%20Future%20Frames/Stabnet_vgg19_16x16.ipynb#X14sZmlsZQ%3D%3D?line=7'>8</a>\u001b[0m         St, St_1, Igt , flow,features \u001b[39m=\u001b[39m batch\n",
      "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:633\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    630\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m    631\u001b[0m     \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m    632\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset()  \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[1;32m--> 633\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[0;32m    634\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[0;32m    635\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m    636\u001b[0m         \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m    637\u001b[0m         \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n",
      "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:677\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    675\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_next_data\u001b[39m(\u001b[39mself\u001b[39m):\n\u001b[0;32m    676\u001b[0m     index \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_next_index()  \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m--> 677\u001b[0m     data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_dataset_fetcher\u001b[39m.\u001b[39;49mfetch(index)  \u001b[39m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m    678\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory:\n\u001b[0;32m    679\u001b[0m         data \u001b[39m=\u001b[39m _utils\u001b[39m.\u001b[39mpin_memory\u001b[39m.\u001b[39mpin_memory(data, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_pin_memory_device)\n",
      "File \u001b[1;32mc:\\Users\\VINY\\anaconda3\\envs\\DUTCode\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py:32\u001b[0m, in \u001b[0;36m_IterableDatasetFetcher.fetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m     30\u001b[0m \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m possibly_batched_index:\n\u001b[0;32m     31\u001b[0m     \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m---> 32\u001b[0m         data\u001b[39m.\u001b[39mappend(\u001b[39mnext\u001b[39;49m(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdataset_iter))\n\u001b[0;32m     33\u001b[0m     \u001b[39mexcept\u001b[39;00m \u001b[39mStopIteration\u001b[39;00m:\n\u001b[0;32m     34\u001b[0m         \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mended \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
      "File \u001b[1;32mc:\\Users\\VINY\\Desktop\\Stabnet multigrid2-DeepStab Modded - Future Frames\\datagen.py:42\u001b[0m, in \u001b[0;36mDatagen.__call__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m     40\u001b[0m u_cap\u001b[39m.\u001b[39mset(cv2\u001b[39m.\u001b[39mCAP_PROP_POS_FRAMES, frame_idx)\n\u001b[0;32m     41\u001b[0m \u001b[39mfor\u001b[39;00m i,pos \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(\u001b[39mrange\u001b[39m(frame_idx \u001b[39m-\u001b[39m SKIP, frame_idx \u001b[39m+\u001b[39m SKIP \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m, SKIP \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m \u001b[39m2\u001b[39m)):\n\u001b[1;32m---> 42\u001b[0m     u_cap\u001b[39m.\u001b[39;49mset(cv2\u001b[39m.\u001b[39;49mCAP_PROP_POS_FRAMES,pos \u001b[39m-\u001b[39;49m \u001b[39m1\u001b[39;49m)\n\u001b[0;32m     43\u001b[0m     _,temp \u001b[39m=\u001b[39m u_cap\u001b[39m.\u001b[39mread()\n\u001b[0;32m     44\u001b[0m     temp \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpreprocess(temp)\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "cv2.namedWindow('window',cv2.WINDOW_NORMAL)\n",
    "# Training loop\n",
    "for epoch in range(starting_epoch,EPOCHS):  \n",
    "    # Generate the data for each iteration\n",
    "    running_loss = 0\n",
    "    for idx,batch in enumerate(data_loader):\n",
    "        start = time()\n",
    "        St, St_1, Igt , flow,features = batch\n",
    "        # Move the data to GPU if available\n",
    "        St = St.to(device)\n",
    "        St_1 = St_1.to(device)\n",
    "        Igt = Igt.to(device)\n",
    "        It = St[:,6:9,...].to(device)\n",
    "        It_1 = St_1[:,6:9,...].to(device)\n",
    "        flow = flow.to(device)\n",
    "        features = features.to(device)\n",
    "        # Forward pass through the Siamese Network\n",
    "        \n",
    "        transform0 = stabnet(St)\n",
    "        transform1 = stabnet(St_1)\n",
    "\n",
    "        warped0,warp_field = get_warp(transform0,It)\n",
    "        warped1,_ = get_warp(transform1,It_1)\n",
    "        # Compute the losses\n",
    "        #stability loss\n",
    "        pixel_loss = 10 * F.mse_loss(warped0, Igt)\n",
    "        feat_loss = feature_loss(features,warp_field)\n",
    "        stability_loss = pixel_loss + feat_loss\n",
    "        #shape_loss\n",
    "        sh_loss = shape_loss(transform0)\n",
    "        #temporal loss\n",
    "        warped2 = dense_warp(warped1, flow)\n",
    "        temp_loss =  10 * F.mse_loss(warped0,warped2)\n",
    "        # Perform backpropagation and update the model parameters\n",
    "        optimizer.zero_grad()\n",
    "        total_loss =  stability_loss + sh_loss + temp_loss\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        \n",
    "        means = np.array([0.485, 0.456, 0.406],dtype = np.float32)\n",
    "        stds = np.array([0.229, 0.224, 0.225],dtype = np.float32)\n",
    "\n",
    "        img = warped0[0,...].detach().cpu().permute(1,2,0).numpy()\n",
    "        img *= stds\n",
    "        img += means\n",
    "        img = np.clip(img * 255.0,0,255).astype(np.uint8)\n",
    "\n",
    "        img1 = Igt[0,...].cpu().permute(1,2,0).numpy()\n",
    "        img1 *= stds\n",
    "        img1 += means\n",
    "        img1 = np.clip(img1 * 255.0,0,255).astype(np.uint8)\n",
    "        #draw features\n",
    "        stable_features = ((features[:, :, 0, :] + 1) / 2) * torch.tensor([255, 255], dtype=torch.float).to(device)\n",
    "        unstable_features = ((features[:, :, 1, :] + 1) / 2) * torch.tensor([255, 255], dtype=torch.float).to(device)\n",
    "        # Clip the features to the range [0, 255]\n",
    "        stable_features = torch.clamp(stable_features, min=0, max=255).cpu().numpy()\n",
    "        unstable_features = torch.clamp(unstable_features, min=0, max=255).cpu().numpy()\n",
    "        img = draw_features(img,unstable_features[0,...],color = (0,255,0))\n",
    "        img1 = draw_features(img1,stable_features[0,...],color = (0,0,255))\n",
    "        conc = cv2.hconcat([img,img1])\n",
    "        cv2.imshow('window',conc)\n",
    "        if cv2.waitKey(1) & 0xFF == ord('9'):\n",
    "            break\n",
    "        \n",
    "        running_loss += total_loss.item()\n",
    "        print(f\"\\rEpoch: {epoch}, Batch:{idx}, loss:{running_loss / (idx % 100 + 1)}\\\n",
    "        pixel_loss:{pixel_loss.item()}, feature_loss:{feat_loss.item()} ,temp:{temp_loss.item()}, shape:{sh_loss.item()}\",end = \"\")\n",
    "        if idx % 100 == 99:\n",
    "            writer.add_scalar('training_loss',\n",
    "                                running_loss / 100, \n",
    "                                epoch * 41328 // batch_size + idx)\n",
    "            running_loss = 0.0\n",
    "            # Get current date and time as a formatted string\n",
    "            current_datetime = datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n",
    "\n",
    "            # Append the formatted date-time string to the model filename\n",
    "            model_path = os.path.join(ckpt_dir, f'stabnet_{current_datetime}.pth')\n",
    "            \n",
    "            torch.save({'model': stabnet.state_dict(),\n",
    "                        'optimizer' : optimizer.state_dict(),\n",
    "                        'epoch' : epoch}\n",
    "                        ,model_path)\n",
    "        del St, St_1, It, Igt, It_1,flow,features\n",
    "cv2.destroyAllWindows()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DUTCode",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
